{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# PyTorch Hands-On Introduction (2): MNIST Handwritten Digit Recognition with a BP Neural Network"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Topics covered\n",
    "\n",
    "[Basic usage of PyTorch nn.Module](https://blog.csdn.net/zhaohongfei_358/article/details/122797244)\n",
    "\n",
    "[Basic usage of PyTorch nn.Linear](https://blog.csdn.net/zhaohongfei_358/article/details/122797190)\n",
    "\n",
    "[Basic usage of torchvision Transforms](https://blog.csdn.net/zhaohongfei_358/article/details/122799782)\n",
    "\n",
    "[Basic usage of DataLoader in PyTorch](https://blog.csdn.net/zhaohongfei_358/article/details/122742656)\n",
    "\n",
    "[NLLLoss and CrossEntropyLoss in PyTorch explained](https://blog.csdn.net/qq_22210253/article/details/85229988)\n",
    "\n",
    "[How to choose the number of layers and hidden-layer neurons in a neural network](https://zhuanlan.zhihu.com/p/100419971)"
   ]
  },
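  {
   "cell_type": "markdown",
   "source": [
    "# About this article\n",
    "\n",
    "This article uses a BP neural network (that is, an ordinary fully connected feed-forward network trained with backpropagation) to recognize handwritten digits from the MNIST dataset. Without further ado, let's get started.\n",
    "\n",
    "The environment used in this article is:\n",
    "\n",
    "```\n",
    "python==3.8.5\n",
    "torch==1.10.2\n",
    "torchvision==0.11.3\n",
    "matplotlib==3.2.2\n",
    "```\n",
    "\n",
    "First, import the required packages:"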
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from time import time\n",
    "from torchvision import datasets, transforms\n",
    "from torch import nn, optim"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
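  {
   "cell_type": "markdown",
   "source": [
    "Define the `transform` object, which specifies how each image in the dataset is preprocessed. Here `ToTensor()` converts an image to a tensor with values in [0, 1], and `Normalize((0.5,), (0.5,))` then rescales those values to [-1, 1]:"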
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "transform = transforms.Compose([transforms.ToTensor(),\n",
    "                                transforms.Normalize((0.5,), (0.5,)),])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
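  {
   "cell_type": "markdown",
   "source": [
    "To make the effect of this transform concrete, here is a minimal sketch that is not part of the original pipeline. It checks by hand that `Normalize((0.5,), (0.5,))` computes `(x - 0.5) / 0.5` per channel, which maps values from [0, 1] to [-1, 1]. The tensor `fake_image` is a made-up stand-in for a real MNIST image after `ToTensor()`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "\n",
    "# Hypothetical 1x2x2 'image' with values already in [0, 1], as ToTensor() would produce\n",
    "fake_image = torch.tensor([[[0.0, 0.25], [0.5, 1.0]]])\n",
    "\n",
    "normalize = transforms.Normalize((0.5,), (0.5,))\n",
    "normalized = normalize(fake_image)\n",
    "\n",
    "# Normalize computes (x - mean) / std for each channel\n",
    "expected = (fake_image - 0.5) / 0.5\n",
    "\n",
    "print(normalized)\n",
    "print(torch.allclose(normalized, expected))\n",
    "```\n",
    "\n",
    "Centering the inputs around zero in this way is a common preprocessing choice; the values 0.5 and 0.5 are the ones used in this article, not the true mean and standard deviation of MNIST."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },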
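  {
   "cell_type": "markdown",
   "source": [
    "Download and load the training set using the API that PyTorch (torchvision) provides. If the download fails, you can fetch the data from this [Baidu Netdisk link](https://pan.baidu.com/s/1NmxIlPhaeKSz_kFwCOn6rA?pwd=6hfa) and unzip it."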
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "data": {
      "text/plain": "Dataset MNIST\n    Number of datapoints: 60000\n    Root location: train_set\n    Split: Train\n    StandardTransform\nTransform: Compose(\n               ToTensor()\n               Normalize(mean=(0.5,), std=(0.5,))\n           )"
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_set = datasets.MNIST('train_set', # root directory to download into\n",
    "                          download=not os.path.exists('train_set'), # download only if the directory does not exist yet\n",
    "                          train=True, # load the training split\n",
    "                          transform=transform # transform applied to every image\n",
    "                         )\n",
    "train_set"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Once the download finishes, we can see that the training set contains 60,000 samples. Next, download the test set:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to test_set\\MNIST\\raw\\train-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/9912422 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "b1a7e1c62c8b45479ef7003bd82b6c67"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting test_set\\MNIST\\raw\\train-images-idx3-ubyte.gz to test_set\\MNIST\\raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to test_set\\MNIST\\raw\\train-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/28881 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "1f24ed72115e43a0bf2401eb904345e7"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting test_set\\MNIST\\raw\\train-labels-idx1-ubyte.gz to test_set\\MNIST\\raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to test_set\\MNIST\\raw\\t10k-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/1648877 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "60b31908556446d2b9352c3cee8c556e"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting test_set\\MNIST\\raw\\t10k-images-idx3-ubyte.gz to test_set\\MNIST\\raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to test_set\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "text/plain": "  0%|          | 0/4542 [00:00<?, ?it/s]",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "f96fb86f6acd476e9f2dd6916882da5a"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting test_set\\MNIST\\raw\\t10k-labels-idx1-ubyte.gz to test_set\\MNIST\\raw\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": "Dataset MNIST\n    Number of datapoints: 10000\n    Root location: test_set\n    Split: Test\n    StandardTransform\nTransform: Compose(\n               ToTensor()\n               Normalize(mean=(0.5,), std=(0.5,))\n           )"
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_set = datasets.MNIST('test_set',\n",
    "                        download=not os.path.exists('test_set'),\n",
    "                        train=False,\n",
    "                        transform=transform\n",
    "                       )\n",
    "test_set"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "The test set contains 10,000 samples.\n",
    "\n",
    "Next, build `DataLoader` objects for the training and test sets:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 1, 28, 28])\n",
      "torch.Size([64])\n"
     ]
    }
   ],
   "source": [
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)\n",
    "\n",
    "dataiter = iter(train_loader)\n",
    "images, labels = next(dataiter)  # use the built-in next(); the iterator's .next() method is missing in newer PyTorch releases\n",
    "\n",
    "print(images.shape)\n",
    "print(labels.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Above, the images are grouped into batches of 64. Each image has a single channel (grayscale) and is 28x28 pixels. Let's plot one of them:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "data": {
      "text/plain": "<Figure size 432x288 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOJ0lEQVR4nO3df6xU9Z3G8edZrP4BDQHvFQkF6TYaNRuXkhFNNA2bSv2FwZp0U/6oGM1SE1RqatC4RjRqghtbc42GeF1/4EqpNZSIBt0SrD/4p3E0LOISlUVoQbxcYgw0QbvgZ/+4h80V73znMnPmB3zfr+RmZs4z554PAw9n7py5cxwRAnDi+7tODwCgPSg7kAnKDmSCsgOZoOxAJk5q58Z6enpi+vTp7dwkkJUdO3Zo3759Hilrquy2L5PUJ2mMpH+PiGWp+0+fPl3VarWZTQJIqFQqNbOGn8bbHiPpMUmXSzpX0nzb5zb6/QC0VjM/s8+StC0itkfE3yT9VtK8csYCULZmyj5F0l+G3d5VLPsa2wttV21XBwcHm9gcgGY0U/aRXgT4xntvI6I/IioRUent7W1icwCa0UzZd0maOuz2dyR90tw4AFqlmbK/LelM29+1fbKkn0paW85YAMrW8KG3iDhk+yZJ/6mhQ29PRcT7pU0GoFRNHWePiHWS1pU0C4AW4u2yQCYoO5AJyg5kgrIDmaDsQCYoO5AJyg5kgrIDmaDsQCYoO5AJyg5kgrIDmaDsQCYoO5AJyg5kgrIDmaDsQCYoO5AJyg5kgrIDmaDsQCYoO5AJyg5kgrIDmaDsQCYoO5AJyg5kgrIDmaDsQCYoO5CJpk7ZbHuHpAOSDks6FBGVMoYCUL6myl74p4jYV8L3AdBCPI0HMtFs2UPSH2y/Y3vhSHewvdB21XZ1cHCwyc0BaFSzZb8oImZKulzSIts/OPoOEdEfEZWIqPT29ja5OQCNaqrsEfFJcblX0hpJs8oYCkD5Gi677bG2v33kuqQfSdpS1mAAytXMq/GTJK2xfeT7/CYiXi1lKgCla7jsEbFd0j+WOAuAFuLQG5AJyg5kgrIDmaDsQCYoO5CJMn4RBuhKW7bUftvHpk2bkuvefPPNyfzw4cPJfP/+/cm8E9izA5mg7EAmKDuQCcoOZIKyA5mg7EAmKDuQCY6zn+Bee+21ZF6tVpP5kiVLyhznaz799NNk/sEHHyTz559/PpmvXLmyZtbscfDzzz8/me/evTuZT5kypantN4I9O5AJyg5kgrIDmaDsQCYoO5AJyg5kgrIDmeA4+wngmWeeqZk99NBDyXXrHe9dtGhRMn/llVeS+VtvvVUze/bZZ5Prfv7558m8lcaPH5/MV69encw7cRy9HvbsQCYoO5AJyg5kgrIDmaDsQCYoO5AJyg5kguPsx4Gnn346mS9evLhmduDAgeS627dvT+aVSiWZf/zxx8n8yy+/TOYpp59+ejIfGBhI5hFRM+vp6Umuu27dumQ+derUZN6N6u7ZbT9le6/tLcOWTbS93vZHxeWE1o4JoFmjeRr/jKTLjlp2h6QNEXGmpA3FbQBdrG7ZI+JNSZ8dtXiepBXF9RWSri53LABla/QFukkRsUeSisvTat3R9kLbVdvVwcHBBjcHoFktfzU+IvojohIRld7e3lZvDkANjZZ9wPZkSSou95Y3EoBWaLTsayUtKK4vkPRiOeMAaJW6x9ltr5I0W1KP7V2SlkpaJul3tm+Q9GdJP2nlkMe7Q4cOJfPrrrsumb/00kvJvN6x9JSDBw8m823btiXzU089NZlfeeWVNbMbb7wxue769euT+V133ZXMJ06cWDNbs2ZNct16nwt/PKpb9oiYXyP6YcmzAGgh3i4LZIKyA5mg7EAmKDuQCcoOZIJfcW2Da6+9NpmvWrWqZds+44wzkvntt9+ezGfOnJnML7jggmOe6Yjly5cn876+vmQ+adKkZP7EE0/UzOr9uU5E7NmBTFB2
IBOUHcgEZQcyQdmBTFB2IBOUHcgEx9nb4Kyzzmpq/dmzZyfza665pmY2b9685LrTpk1rZKRRS30k87333ptcd+/e9GeiPPjgg8l8zpw5NbNTTjklue6JiD07kAnKDmSCsgOZoOxAJig7kAnKDmSCsgOZcOq0tmWrVCpRrVbbtr1uUe8xPnz4cDIfM2ZMMrd9zDOVpd7f59y5c2tm9U65vGzZsmS+ZMmSZN7Jx6VTKpWKqtXqiH9w9uxAJig7kAnKDmSCsgOZoOxAJig7kAnKDmSC32dvg3rHe086qXv/Gnbv3p3Mb7vttmSeOpZ+4YUXJte99dZbk3mOx9GbUXfPbvsp23ttbxm27B7bu21vKr6uaO2YAJo1mqfxz0i6bITlD0fEjOKr9seRAOgKdcseEW9K+qwNswBooWZeoLvJ9ubiaf6EWneyvdB21XZ1cHCwic0BaEajZV8u6XuSZkjaI+lXte4YEf0RUYmISm9vb4ObA9CshsoeEQMRcTgivpL0hKRZ5Y4FoGwNld325GE3fyxpS637AugOdQ/w2l4labakHtu7JC2VNNv2DEkhaYekn7duRLTSzp07k/n999+fzN94441knvrM+xdeeCG57sknn5zMcWzqlj0i5o+w+MkWzAKghXi7LJAJyg5kgrIDmaDsQCYoO5CJ7v3dSpTi4MGDyfyWW25J5mvXrk3m48ePT+Z9fX01s56enuS6KBd7diATlB3IBGUHMkHZgUxQdiATlB3IBGUHMsFx9hPco48+mszrHUev9zHXS5cuTebnnXdeMkf7sGcHMkHZgUxQdiATlB3IBGUHMkHZgUxQdiATHGc/DkREMn/99ddrZkuWLGlq23Pnzk3m9U6rjO7Bnh3IBGUHMkHZgUxQdiATlB3IBGUHMkHZgUxwnP04sH79+mR+6aWXNvy9zz777GT+3HPPNfy90V3q7tltT7X9R9tbbb9ve3GxfKLt9bY/Ki4ntH5cAI0azdP4Q5J+GRHnSLpQ0iLb50q6Q9KGiDhT0obiNoAuVbfsEbEnIt4trh+QtFXSFEnzJK0o7rZC0tUtmhFACY7pBTrb0yV9X9KfJE2KiD3S0H8Ikk6rsc5C21Xb1cHBwSbHBdCoUZfd9jhJqyX9IiL2j3a9iOiPiEpEVHp7exuZEUAJRlV229/SUNFXRsTvi8UDticX+WRJe1szIoAy1D30ZtuSnpS0NSJ+PSxaK2mBpGXF5YstmTADH374YTK/7777Gv7el1xySTJ/+OGHk/nYsWMb3ja6y2iOs18k6WeS3rO9qVh2p4ZK/jvbN0j6s6SftGRCAKWoW/aI2CjJNeIfljsOgFbh7bJAJig7kAnKDmSCsgOZoOxAJvgV1zao9zbhBx54IJlv3Lgxmc+YMaNm9sgjjyTXPeecc5I5Thzs2YFMUHYgE5QdyARlBzJB2YFMUHYgE5QdyATH2Uvw6quvJvPHHnssmb/88svJfObMmcn88ccfr5lxHB1HsGcHMkHZgUxQdiATlB3IBGUHMkHZgUxQdiATHGcfpf7+/prZ3XffnVx3YGAgmV988cXJfOXKlcl82rRpyRyQ2LMD2aDsQCYoO5AJyg5kgrIDmaDsQCYoO5CJ0ZyffaqkZyWdLukrSf0R0Wf7Hkn/IunIh6LfGRHrWjVoq+3cuTOZL168uGb2xRdfJNedM2dOMl++fHky5zg6yjCaN9UckvTLiHjX9rclvWN7fZE9HBEPtW48AGUZzfnZ90jaU1w/YHurpCmtHgxAuY7pZ3bb0yV9X9KfikU32d5s+ynbE2qss9B21Xa13mmQALTOqMtue5yk1ZJ+ERH7JS2X9D1JMzS05//VSOtFRH9EVCKi0tvb2/zEABoyqrLb/paGir4yIn4vSRExEBGHI+IrSU9ImtW6MQE0q27ZbVvSk5K2RsSvhy2fPOxuP5a0pfzxAJRlNK/GXyTpZ5Les72pWHanpPm2Z0gKSTsk/bwF87XN5s2bk3nq8Nr111+fXLev
ry+Zjxs3LpkDZRjNq/EbJXmE6Lg9pg7kiHfQAZmg7EAmKDuQCcoOZIKyA5mg7EAm+CjpwlVXXZXMI6JNkwCtwZ4dyARlBzJB2YFMUHYgE5QdyARlBzJB2YFMuJ3Hj20PShr+mc09kva1bYBj062zdetcErM1qszZzoiIET//ra1l/8bG7WpEVDo2QEK3ztatc0nM1qh2zcbTeCATlB3IRKfL3t/h7ad062zdOpfEbI1qy2wd/ZkdQPt0es8OoE0oO5CJjpTd9mW2P7C9zfYdnZihFts7bL9ne5Ptaodnecr2Xttbhi2baHu97Y+KyxHPsdeh2e6xvbt47DbZvqJDs021/UfbW22/b3txsbyjj11irrY8bm3/md32GEkfSpojaZektyXNj4j/busgNdjeIakSER1/A4btH0j6q6RnI+IfimX/JumziFhW/Ec5ISJu75LZ7pH0106fxrs4W9Hk4acZl3S1pOvUwccuMdc/qw2PWyf27LMkbYuI7RHxN0m/lTSvA3N0vYh4U9JnRy2eJ2lFcX2Fhv6xtF2N2bpCROyJiHeL6wckHTnNeEcfu8RcbdGJsk+R9Jdht3epu873HpL+YPsd2ws7PcwIJkXEHmnoH4+k0zo8z9Hqnsa7nY46zXjXPHaNnP68WZ0o+0inkuqm438XRcRMSZdLWlQ8XcXojOo03u0ywmnGu0Kjpz9vVifKvkvS1GG3vyPpkw7MMaKI+KS43CtpjbrvVNQDR86gW1zu7fA8/6+bTuM90mnG1QWPXSdPf96Jsr8t6Uzb37V9sqSfSlrbgTm+wfbY4oUT2R4r6UfqvlNRr5W0oLi+QNKLHZzla7rlNN61TjOuDj92HT/9eUS0/UvSFRp6Rf5/JP1rJ2aoMdffS/qv4uv9Ts8maZWGntb9r4aeEd0g6VRJGyR9VFxO7KLZ/kPSe5I2a6hYkzs028Ua+tFws6RNxdcVnX7sEnO15XHj7bJAJngHHZAJyg5kgrIDmaDsQCYoO5AJyg5kgrIDmfg/SP9Cn800eXAAAAAASUVORK5CYII=\n"
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(images[0].numpy().squeeze(), cmap='gray_r');"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "That completes the preparation.\n",
    "\n",
    "---\n",
    "\n",
    "Now define the neural network:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "class NeuralNetwork(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # First linear layer:\n",
    "        # the input is a flattened image (28x28 = 784 values),\n",
    "        # the output feeds the first hidden layer, which has 128 units.\n",
    "        self.linear1 = nn.Linear(28 * 28, 128)\n",
    "        # ReLU activation for the first hidden layer\n",
    "        self.relu1 = nn.ReLU()\n",
    "        # Second linear layer:\n",
    "        # the input is the output of the first hidden layer,\n",
    "        # the output feeds the second hidden layer, which has 64 units.\n",
    "        self.linear2 = nn.Linear(128, 64)\n",
    "        # ReLU activation for the second hidden layer\n",
    "        self.relu2 = nn.ReLU()\n",
    "        # Third linear layer:\n",
    "        # the input is the output of the second hidden layer,\n",
    "        # the output is the output layer, which has 10 units (one per digit).\n",
    "        self.linear3 = nn.Linear(64, 10)\n",
    "        # Log-softmax turns the outputs into log-probabilities, as NLLLoss expects\n",
    "        self.softmax = nn.LogSoftmax(dim=1)\n",
    "\n",
    "        # The same stack can be written more compactly with nn.Sequential.\n",
    "        # It is defined here for illustration only; forward() below does not use it.\n",
    "        self.model = nn.Sequential(nn.Linear(28 * 28, 128),\n",
    "                      nn.ReLU(),\n",
    "                      nn.Linear(128, 64),\n",
    "                      nn.ReLU(),\n",
    "                      nn.Linear(64, 10),\n",
    "                      nn.LogSoftmax(dim=1)\n",
    "                 )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Forward pass of the network.\n",
    "        x: image data, with shape (64, 1, 28, 28) for a full batch\n",
    "        \"\"\"\n",
    "        # Flatten x to shape (batch_size, 784), e.g. (64, 784)\n",
    "        x = x.view(x.shape[0], -1)\n",
    "\n",
    "        # Run the forward pass layer by layer\n",
    "        x = self.linear1(x)\n",
    "        x = self.relu1(x)\n",
    "        x = self.linear2(x)\n",
    "        x = self.relu2(x)\n",
    "        x = self.linear3(x)\n",
    "        x = self.softmax(x)\n",
    "\n",
    "        # If only self.model were kept, the steps above could be replaced by x = self.model(x).\n",
    "\n",
    "        return x"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "model = NeuralNetwork()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
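  {
   "cell_type": "markdown",
   "source": [
    "With the network defined, the next step is the loss function. Here we use the **negative log-likelihood** loss (`NLLLoss`, [negative log likelihood loss](https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html)), which is commonly used for classification. It expects log-probabilities as input, which is why the network ends with `LogSoftmax`. [See this link for details](https://blog.csdn.net/qq_22210253/article/details/85229988)"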
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "criterion = nn.NLLLoss()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
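  {
   "cell_type": "markdown",
   "source": [
    "As a small illustration of how this loss relates to the more familiar cross-entropy loss, the sketch below checks that `LogSoftmax` followed by `NLLLoss` gives the same value as `CrossEntropyLoss` applied to the raw scores. The tensors `logits` and `targets` are made-up values, not data from this notebook.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "logits = torch.randn(4, 10)           # made-up raw scores for 4 samples and 10 classes\n",
    "targets = torch.tensor([3, 0, 7, 9])  # made-up class labels\n",
    "\n",
    "# Route 1: log-softmax followed by NLLLoss (the combination used in this notebook)\n",
    "log_probs = nn.LogSoftmax(dim=1)(logits)\n",
    "loss_nll = nn.NLLLoss()(log_probs, targets)\n",
    "\n",
    "# Route 2: CrossEntropyLoss applied directly to the raw scores\n",
    "loss_ce = nn.CrossEntropyLoss()(logits, targets)\n",
    "\n",
    "print(loss_nll.item(), loss_ce.item())\n",
    "print(torch.allclose(loss_nll, loss_ce))\n",
    "```\n",
    "\n",
    "Either route would work for this task. Keeping `LogSoftmax` inside the network has the side effect that the model's outputs are directly interpretable as log-probabilities."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },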
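  {
   "cell_type": "markdown",
   "source": [
    "Next, define the optimizer. Here we use stochastic gradient descent (SGD) with a learning rate of 0.003 and momentum set to 0.9. Note that 0.9 is a commonly used value, not the default (the default `momentum` of `optim.SGD` is 0). Momentum speeds up convergence and damps oscillations; it is not a safeguard against overfitting."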
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [],
   "source": [
    "optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
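  {
   "cell_type": "markdown",
   "source": [
    "The momentum update can be sketched by hand to show what the optimizer does on each step. The example below is an illustration under simple assumptions (a single made-up parameter `w`, the toy loss `w ** 2`, and no dampening, weight decay, or Nesterov momentum): the velocity is updated as `v = momentum * v + grad`, and the parameter as `w = w - lr * v`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import optim\n",
    "\n",
    "lr, momentum = 0.003, 0.9\n",
    "\n",
    "# Parameter updated by PyTorch's SGD\n",
    "w = torch.tensor([1.0], requires_grad=True)\n",
    "optimizer = optim.SGD([w], lr=lr, momentum=momentum)\n",
    "\n",
    "# The same parameter updated by hand\n",
    "w_manual = 1.0\n",
    "velocity = 0.0\n",
    "\n",
    "for _ in range(3):\n",
    "    # PyTorch update on the toy loss w^2\n",
    "    loss = (w ** 2).sum()\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Manual update: the gradient of w^2 is 2w\n",
    "    grad = 2 * w_manual\n",
    "    velocity = momentum * velocity + grad\n",
    "    w_manual = w_manual - lr * velocity\n",
    "\n",
    "# The two values should agree up to floating-point precision\n",
    "print(w.item(), w_manual)\n",
    "```\n",
    "\n",
    "Because the velocity accumulates past gradients, consecutive steps in the same direction grow larger, which is where the speed-up comes from."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },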
  {
   "cell_type": "markdown",
   "source": [
    "With the preparation done, start training the model on the training set:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 - Training loss: 0.6294137474252726\n",
      "Epoch 1 - Training loss: 0.27885234054884933\n",
      "Epoch 2 - Training loss: 0.2180362274207032\n",
      "Epoch 3 - Training loss: 0.17646600610650043\n",
      "Epoch 4 - Training loss: 0.14901786734228895\n",
      "Epoch 5 - Training loss: 0.12897429347081957\n",
      "Epoch 6 - Training loss: 0.11274210547309504\n",
      "Epoch 7 - Training loss: 0.10064082649717135\n",
      "Epoch 8 - Training loss: 0.09091206552532277\n",
      "Epoch 9 - Training loss: 0.08191311532861865\n",
      "Epoch 10 - Training loss: 0.07508732156971021\n",
      "Epoch 11 - Training loss: 0.07009464516681728\n",
      "Epoch 12 - Training loss: 0.0649078527074764\n",
      "Epoch 13 - Training loss: 0.06004000982112769\n",
      "Epoch 14 - Training loss: 0.054164604703361575\n",
      "\n",
      "Training Time (in minutes) = 0.9925608317057292\n"
     ]
    }
   ],
   "source": [
    "time0 = time() # record the start time\n",
    "epochs = 15 # train for 15 epochs in total\n",
    "for e in range(epochs):\n",
    "    running_loss = 0 # accumulated loss for this epoch\n",
    "    for images, labels in train_loader:\n",
    "        # forward pass to get the predictions\n",
    "        output = model(images)\n",
    "\n",
    "        # compute the loss\n",
    "        loss = criterion(output, labels)\n",
    "\n",
    "        # backpropagate to compute the gradients\n",
    "        loss.backward()\n",
    "\n",
    "        # update the weights\n",
    "        optimizer.step()\n",
    "\n",
    "        # clear the gradients so they do not accumulate into the next batch\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # accumulate the loss\n",
    "        running_loss += loss.item()\n",
    "\n",
    "    # after each epoch, print the average loss per batch\n",
    "    print(\"Epoch {} - Training loss: {}\".format(e, running_loss/len(train_loader)))\n",
    "\n",
    "# print the total training time\n",
    "print(\"\\nTraining Time (in minutes) =\",(time()-time0)/60)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
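  {
   "cell_type": "markdown",
   "source": [
    "On my machine, training finished in about one minute (0.99 minutes in the run shown above). As you can see, the loss decreases steadily from epoch to epoch."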
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Next, evaluate the model:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number Of Images Tested = 10000\n",
      "\n",
      "Model Accuracy = 0.97\n"
     ]
    }
   ],
   "source": [
    "correct_count, all_count = 0, 0\n",
    "model.eval() # switch the model to evaluation mode\n",
    "\n",
    "# Load images from test_loader batch by batch; gradients are not needed for evaluation\n",
    "with torch.no_grad():\n",
    "    for images, labels in test_loader:\n",
    "        # Check every image in this batch\n",
    "        for i in range(len(labels)):\n",
    "            logps = model(images[i])  # forward pass; images[i] has shape (1, 28, 28) and is treated as a batch of one\n",
    "            log_probs = list(logps.numpy()[0]) # log-probabilities of the 10 digits; [0] selects the only image in the batch\n",
    "            pred_label = log_probs.index(max(log_probs)) # the index of the largest log-probability is the prediction\n",
    "            true_label = labels.numpy()[i]\n",
    "            if true_label == pred_label: # check whether the prediction is correct\n",
    "                correct_count += 1\n",
    "            all_count += 1\n",
    "\n",
    "print(\"Number Of Images Tested =\", all_count)\n",
    "print(\"\\nModel Accuracy =\", (correct_count/all_count))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
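  {
   "cell_type": "markdown",
   "source": [
    "The loop above predicts one image at a time. As an alternative, the sketch below evaluates a whole batch per forward pass using `argmax`. The helper `batch_accuracy` and the stand-in objects `fake_images`, `fake_labels`, `fake_loader`, and `fake_model` are hypothetical names introduced only so the example runs on its own; in this notebook one would pass `model` and `test_loader` instead.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "def batch_accuracy(model, loader):\n",
    "    # Hypothetical helper: predicts a whole batch at once instead of image by image\n",
    "    model.eval()\n",
    "    correct, total = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in loader:\n",
    "            preds = model(images).argmax(dim=1)  # class with the highest score for each image\n",
    "            correct += (preds == labels).sum().item()\n",
    "            total += labels.size(0)\n",
    "    return correct / total\n",
    "\n",
    "# Random stand-in data and an untrained stand-in model, used only to make the sketch runnable\n",
    "fake_images = torch.randn(100, 1, 28, 28)\n",
    "fake_labels = torch.randint(0, 10, (100,))\n",
    "fake_loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=64)\n",
    "fake_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))\n",
    "\n",
    "print(batch_accuracy(fake_model, fake_loader))\n",
    "```\n",
    "\n",
    "Since the stand-in model is untrained and the data is random, the printed accuracy is only expected to be near chance level. The point of the sketch is the batched prediction pattern, which avoids one forward pass per image."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },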
  {
   "cell_type": "markdown",
   "source": [
    "In the end, the model from this run reaches an accuracy of about 97% on the test set."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n",
     "is_executing": true
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# References\n",
    "\n",
    "[Handwritten Digit Recognition Using PyTorch — Intro To Neural Networks](https://towardsdatascience.com/handwritten-digit-mnist-pytorch-977b5338e627)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}